Skip to content

Gemma4 HeteroTP support for NIXL KV connector#41169

Closed
ZhanqiuHu wants to merge 49 commits into
vllm-project:mainfrom
ZhanqiuHu:gemma4-heterotp-support
Closed

Gemma4 HeteroTP support for NIXL KV connector#41169
ZhanqiuHu wants to merge 49 commits into
vllm-project:mainfrom
ZhanqiuHu:gemma4-heterotp-support

Conversation

@ZhanqiuHu

Copy link
Copy Markdown
Contributor

Depends on #40731

@mergify

mergify Bot commented Apr 28, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZhanqiuHu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 28, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the NIXL connector to use a plan-based transfer architecture, introducing EngineTransferPlan to handle complex transfer geometries for hybrid SSM+Attention and HeteroTP models. While this architecture simplifies the transfer hot path, several critical bugs were identified: the splitting logic incorrectly assumes uniform parameters across all attention groups, and an incorrect boundary parameter breaks Mamba state transfers. Additionally, the block trimming logic must be updated to avoid corrupting Mamba state, and plan generation needs to be more robust when encountering homogeneous remote engines to prevent assertion failures.

Comment on lines +1052 to +1056
remote_tpb = tuple(nixl_agent_meta.tokens_per_block_per_group or ())
assert len(remote_tpb) == len(self._group_kinds), (
f"Remote tokens_per_block_per_group length "
f"{len(remote_tpb)} != {len(self._group_kinds)} groups"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The assertion len(remote_tpb) == len(self._group_kinds) will fail when connecting to a remote engine that is homogeneous (e.g., an older vLLM version or a different model configuration). In such cases, nixl_agent_meta.tokens_per_block_per_group is None, causing remote_tpb to be an empty tuple, while self._group_kinds is always populated. For homogeneous remotes, remote_tpb should default to a tuple of nixl_agent_meta.block_size repeated for each group.

Suggested change
remote_tpb = tuple(nixl_agent_meta.tokens_per_block_per_group or ())
assert len(remote_tpb) == len(self._group_kinds), (
f"Remote tokens_per_block_per_group length "
f"{len(remote_tpb)} != {len(self._group_kinds)} groups"
)
remote_tpb = tuple(nixl_agent_meta.tokens_per_block_per_group) if nixl_agent_meta.tokens_per_block_per_group is not None else (nixl_agent_meta.block_size,) * len(self._group_kinds)
assert len(remote_tpb) == len(self._group_kinds), (
f"Remote tokens_per_block_per_group length "
f"{len(remote_tpb)} != {len(self._group_kinds)} groups"
)

Comment on lines +1101 to +1106
fa_slot = fa_slot_map.get(p_rank, 0)

handle: list[tuple[int, int, int]] = []
for j, (addr, local_len, dev) in enumerate(src_blocks_data):
if j < num_fa_descs:
chunk = local_len // fa_num_splits

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The splitting logic in build_local_splits_from_plan incorrectly assumes that all attention groups share the same split count and slot mapping as the first group (plan.source_ranks_per_group[0] and plan.rank_to_attention_slot[0]). In HeteroTP models like Gemma4, different attention groups (e.g., SWA vs. FA) can have different numbers of KV heads, which may result in different numbers of remote source ranks (due to GQA deduplication) and different head-to-slot assignments. Using group 0's parameters for all attention descriptors will lead to incorrect memory slicing and data corruption for other groups.

Comment on lines +1166 to +1170
for handle_data in build_local_splits_from_plan(
plan,
self.src_blocks_data,
self.num_descs,
):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Passing self.num_descs as the num_fa_descs boundary to build_local_splits_from_plan is incorrect if self.num_descs represents the total number of descriptors (FA + SSM). This causes the logic in transfer_plan.py to treat all descriptors as attention descriptors (since j < num_fa_descs will always be true), and has_ssm_descs will be incorrectly evaluated as False. This will break the 3-read transfer logic for Mamba state.

Comment on lines +1899 to +1913
# Partial prefix cache hit: trim to the shorter of local/remote.
# After ReadSpec construction, local descriptor IDs and remote
# block IDs should already have matched lengths per group
# (gather-read pairing ensures this). Trim from the head to
# keep the tail (newest blocks).
remote_block_ids = list(remote_block_ids)
for i, remote_group in enumerate(remote_block_ids):
num_remote_blocks = len(remote_group)
num_local_blocks = len(local_block_ids[i])
if not self._is_mamba_group[i]:
assert num_local_blocks <= num_remote_blocks
# Partial prefix cache hit: just read uncomputed blocks.
# Skip mamba groups — their blocks represent full state (conv+ssm),
# not per-token data, so trimming would corrupt the transfer.
if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]:
remote_block_ids[i] = remote_group[-num_local_blocks:]
local_block_ids = list(local_block_ids)
for i in range(len(remote_block_ids)):
n_local = len(local_block_ids[i])
n_remote = len(remote_block_ids[i])
n = min(n_local, n_remote)
if n_local > n:
local_block_ids[i] = local_block_ids[i][-n:]
if n_remote > n:
remote_block_ids[i] = remote_block_ids[i][-n:]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The new trimming logic in _read_blocks removes the Mamba-specific check that previously prevented trimming for Mamba groups. Mamba blocks represent full state (conv + SSM) rather than incremental per-token data. Trimming these blocks based on a length mismatch between local and remote requests is likely to result in state corruption. The logic should continue to skip trimming for groups where is_ssm is true.

Suggested change
# Partial prefix cache hit: trim to the shorter of local/remote.
# After ReadSpec construction, local descriptor IDs and remote
# block IDs should already have matched lengths per group
# (gather-read pairing ensures this). Trim from the head to
# keep the tail (newest blocks).
remote_block_ids = list(remote_block_ids)
for i, remote_group in enumerate(remote_block_ids):
num_remote_blocks = len(remote_group)
num_local_blocks = len(local_block_ids[i])
if not self._is_mamba_group[i]:
assert num_local_blocks <= num_remote_blocks
# Partial prefix cache hit: just read uncomputed blocks.
# Skip mamba groups — their blocks represent full state (conv+ssm),
# not per-token data, so trimming would corrupt the transfer.
if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]:
remote_block_ids[i] = remote_group[-num_local_blocks:]
local_block_ids = list(local_block_ids)
for i in range(len(remote_block_ids)):
n_local = len(local_block_ids[i])
n_remote = len(remote_block_ids[i])
n = min(n_local, n_remote)
if n_local > n:
local_block_ids[i] = local_block_ids[i][-n:]
if n_remote > n:
remote_block_ids[i] = remote_block_ids[i][-n:]
for i in range(len(remote_block_ids)):
if self._group_kinds[i].is_ssm:
continue
n_local = len(local_block_ids[i])
n_remote = len(remote_block_ids[i])
n = min(n_local, n_remote)
if n_local > n:
local_block_ids[i] = local_block_ids[i][-n:]
if n_remote > n:
remote_block_ids[i] = remote_block_ids[i][-n:]

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…delBlockTransferPolicy

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…cal dict

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
… and remote_tp_size

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…fo, remove defaults, regroup args

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…mote and physical_blocks_per_logical, regroup

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…npack read spec loop

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…omputation

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…blocks_per_logical removal

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…_policy and use_mla, set num_descs

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…nsfer_info

Add physical_blocks_per_logical to TransferTopology and pass
transfer_topo directly to build_engine_transfer_info, reducing the
method's parameter count from 10 to 6.

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…11→4)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…_id (7→4)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…to_kernel_block_ids

Move model-specific block ID expansion and trimming logic out of
worker.py hot paths into a unified logical_to_kernel_block_ids()
in transfer_plan.py. Per-request functions now consume only the
pre-computed EngineTransferPlan (model-agnostic); model awareness
is confined to init and plan generation (if/else dense vs mamba).

- Add transfer_plan.py with plan generators, executors, and
  logical_to_kernel_block_ids (per-group physical_per_logical).
- Generate EngineTransferPlan during handshake (generate_dense_plan
  or generate_mamba_plan), stored in _transfer_plans dict.
- Replace _logical_to_remote_kernel_block_ids in _read_blocks_for_req
  with plan-based logical_to_kernel_block_ids (no model branching).
- Make block trimming in _read_blocks unconditional (SSM groups are
  no-op due to shared block table).
- Thin-wrap _logical_to_kernel_block_ids for local expansion,
  delegating to the same unified function.
- Add _conv_decomp to worker init for mamba plan generation.

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
The unified logical_to_kernel_block_ids had two bugs:
1. Dense remote: used ratio=1 instead of local kernel ratio
2. Mamba FA remote: used same value for stride and count,
   but old code used remote_ratio as stride, local_ratio as count

Restore the original two-method design:
- _logical_to_kernel_block_ids: local expansion (same as main)
- _logical_to_remote_kernel_block_ids: remote mamba expansion (same as main)

Only difference from main: remote_ratio comes from
plan.remote_physical_blocks_per_logical instead of
self._mamba_phys_ratio[engine_id].

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…ake field

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…nsion

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Per-group TP mapping, sub-descriptor splitting, and block ID
remapping for heterogeneous attention models (SWA + FA groups
with different total_num_kv_heads).

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
When D_TP < P_TP (e.g. 4p2d), local pages are larger than remote pages.
Introduce gather-read: split local blocks into sub-descriptors matching
the remote page size so NIXL can pair them for RDMA.

- Add local_to_remote_page_ratio and remote_blocks_per_local_block to
  EngineTransferPlan; generate_gemma4_plan now handles both split-read
  (2p4d) and gather-read (4p2d) directions
- Add build_fa_local_descs_for_gather_read for local sub-desc registration
- Worker: register gather-read handles, bidirectional block trimming,
  scale local desc-space block count for gather-read
- Fix FullAttentionSpec.merge() and SinkFullAttentionSpec.merge() to
  propagate total_num_kv_heads (was causing AssertionError on Gemma4)
- Add unit tests for gather-read configs (4p2d, 4p1d)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
- Rename sub_desc_index_per_group → remote_desc_offset_per_group
  and _remap_remote_blocks_to_subdesc_ids → _remap_remote_blocks_to_desc_ids
  to eliminate confusing "sub-descriptor" naming throughout.
- Add _pair_gather_group helper with remainder handling for partial
  block fills in gather-read (FA groups).
- Add length-match assertions in both _build_gather_read_specs and
  _read_blocks to catch descriptor/block ID pairing mismatches early.
- Add diagnostic INFO/DEBUG logging for HeteroTP plan generation,
  ReadSpec construction, and _read_blocks transfer details.

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
@ZhanqiuHu ZhanqiuHu force-pushed the gemma4-heterotp-support branch from ca3b517 to 56d12e7 Compare April 29, 2026 22:08
@mergify mergify Bot removed the needs-rebase label Apr 29, 2026
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
@mergify

mergify Bot commented May 23, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZhanqiuHu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant